{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "06-mnist-tpu-training.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WsWdLFMVKqbi"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-tpu-training.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qXO1QLkbRXl0"
      },
      "source": [
        "# TPU training with PyTorch Lightning ⚡\n",
        "\n",
        "In this notebook, we'll train a model on TPUs. Changing one line of code is all you need to that.\n",
        "\n",
        "The most up to documentation related to TPU training can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/tpu.html).\n",
        "\n",
        "---\n",
        "\n",
        "  - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
        "  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
        "  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
        "  - Ask a question on our [official forum](https://forums.pytorchlightning.ai/)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UmKX0Qa1RaLL"
      },
      "source": [
        "### Setup\n",
        "\n",
        "Lightning is easy to install. Simply ```pip install pytorch-lightning```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vAWOr0FZRaIj"
      },
      "source": [
        "! pip install pytorch-lightning -qU"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zepCr1upT4Z3"
      },
      "source": [
        "###  Install Colab TPU compatible PyTorch/TPU wheels and dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AYGWh10lRaF1"
      },
      "source": [
        "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SNHa7DpmRZ-C"
      },
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import random_split, DataLoader\n",
        "\n",
        "# Note - you must have torchvision installed for this example\n",
        "from torchvision.datasets import MNIST\n",
        "from torchvision import transforms\n",
        "\n",
        "import pytorch_lightning as pl\n",
        "from pytorch_lightning.metrics.functional import accuracy"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rjo1dqzGUxt6"
      },
      "source": [
        "### Defining The `MNISTDataModule`\n",
        "\n",
        "Below we define `MNISTDataModule`. You can learn more about datamodules in [docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html) and [datamodule notebook](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pkbrm3YgUxlE"
      },
      "source": [
        "class MNISTDataModule(pl.LightningDataModule):\n",
        "\n",
        "    def __init__(self, data_dir: str = './'):\n",
        "        super().__init__()\n",
        "        self.data_dir = data_dir\n",
        "        self.transform = transforms.Compose([\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0.1307,), (0.3081,))\n",
        "        ])\n",
        "\n",
        "        # self.dims is returned when you call dm.size()\n",
        "        # Setting default dims here because we know them.\n",
        "        # Could optionally be assigned dynamically in dm.setup()\n",
        "        self.dims = (1, 28, 28)\n",
        "        self.num_classes = 10\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # download\n",
        "        MNIST(self.data_dir, train=True, download=True)\n",
        "        MNIST(self.data_dir, train=False, download=True)\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "\n",
        "        # Assign train/val datasets for use in dataloaders\n",
        "        if stage == 'fit' or stage is None:\n",
        "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
        "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
        "\n",
        "        # Assign test dataset for use in dataloader(s)\n",
        "        if stage == 'test' or stage is None:\n",
        "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.mnist_train, batch_size=32)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.mnist_val, batch_size=32)\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        return DataLoader(self.mnist_test, batch_size=32)"
      ],
      "execution_count": null,
      "outputs": []
    },
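    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `[55000, 5000]` lengths passed to `random_split` in `setup` add up to the 60,000 images in the MNIST training set. As a rough illustration of what such a split does, here is a plain-Python sketch that works on indices only. `split_indices` is a hypothetical helper written for this note (it is not part of PyTorch or Lightning), and it uses Python's `random` module rather than torch's generator, so it will not reproduce the exact split that `random_split` produces.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "\n",
        "def split_indices(n_items, lengths, seed=None):\n",
        "    # Hypothetical helper: shuffle the indices, then cut them into parts.\n",
        "    if sum(lengths) != n_items:\n",
        "        raise ValueError('lengths must sum to the dataset size')\n",
        "    indices = list(range(n_items))\n",
        "    random.Random(seed).shuffle(indices)\n",
        "    parts, start = [], 0\n",
        "    for length in lengths:\n",
        "        parts.append(indices[start:start + length])\n",
        "        start += length\n",
        "    return parts\n",
        "\n",
        "\n",
        "# Same sizes as the train/val split used in MNISTDataModule.setup\n",
        "train_idx, val_idx = split_indices(60000, [55000, 5000], seed=0)\n",
        "print(len(train_idx), len(val_idx))\n",
        "```\n",
        "\n",
        "Each image index lands in exactly one of the two parts, which is the property the train/val split relies on."
      ]
    },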
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nr9AqDWxUxdK"
      },
      "source": [
        "### Defining the `LitModel`\n",
        "\n",
        "Below, we define the model `LitMNIST`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YKt0KZkOUxVY"
      },
      "source": [
        "class LitModel(pl.LightningModule):\n",
        "    \n",
        "    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
        "\n",
        "        super().__init__()\n",
        "\n",
        "        self.save_hyperparameters()\n",
        "\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(channels * width * height, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.model(x)\n",
        "        return F.log_softmax(x, dim=1)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        self.log('train_loss', loss, prog_bar=False)\n",
        "        return loss\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        preds = torch.argmax(logits, dim=1)\n",
        "        acc = accuracy(preds, y)\n",
        "        self.log('val_loss', loss, prog_bar=True)\n",
        "        self.log('val_acc', acc, prog_bar=True)\n",
        "        return loss\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n",
        "        return optimizer"
      ],
      "execution_count": null,
      "outputs": []
    },
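    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`forward` returns log-probabilities via `F.log_softmax`, and both step methods pass them to `F.nll_loss`. As a rough illustration of that arithmetic for a single sample, here is a plain-Python sketch. The helper functions and the example scores are made up for this note; this is not how PyTorch implements these operations, and `F.nll_loss` additionally averages over the batch by default.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def log_softmax(scores):\n",
        "    # Subtract the max before exponentiating for numerical stability.\n",
        "    m = max(scores)\n",
        "    log_sum = m + math.log(sum(math.exp(v - m) for v in scores))\n",
        "    return [v - log_sum for v in scores]\n",
        "\n",
        "\n",
        "def nll_loss(log_probs, target):\n",
        "    # Negative log-probability assigned to the correct class.\n",
        "    return -log_probs[target]\n",
        "\n",
        "\n",
        "scores = [2.0, 0.5, -1.0]  # made-up raw scores for three classes\n",
        "log_probs = log_softmax(scores)\n",
        "print('probabilities sum to', sum(math.exp(v) for v in log_probs))\n",
        "print('loss if the label is class 0:', nll_loss(log_probs, 0))\n",
        "print('loss if the label is class 2:', nll_loss(log_probs, 2))\n",
        "```\n",
        "\n",
        "Because the logarithm is monotonic, the class with the highest log-probability is also the most probable one, which is what `torch.argmax(logits, dim=1)` in `validation_step` picks."
      ]
    },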
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uxl88z06cHyV"
      },
      "source": [
        "### TPU Training\n",
        "\n",
        "Lightning supports training on a single TPU core or 8 TPU cores.\n",
        "\n",
        "The Trainer parameters `tpu_cores` defines how many TPU cores to train on (1 or 8) / Single TPU core to train on [1].\n",
        "\n",
        "For Single TPU training, Just pass the TPU core ID [1-8] in a list. Setting `tpu_cores=[5]` will train on TPU core ID 5."
      ]
    },
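    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The accepted `tpu_cores` values described above can be summarised in a small sketch. `describe_tpu_cores` is a hypothetical helper written for this note. It mirrors only the values mentioned in this notebook and is not Lightning's own argument parsing, which may accept other forms.\n",
        "\n",
        "```python\n",
        "def describe_tpu_cores(tpu_cores):\n",
        "    # Hypothetical helper, not part of PyTorch Lightning.\n",
        "    # A one-element list selects a specific core by its ID (1 to 8).\n",
        "    if isinstance(tpu_cores, list):\n",
        "        if len(tpu_cores) == 1 and tpu_cores[0] in range(1, 9):\n",
        "            return 'single core, ID {}'.format(tpu_cores[0])\n",
        "        raise ValueError('expected a list with one core ID from 1 to 8')\n",
        "    # An integer selects how many cores to train on: 1 or 8.\n",
        "    if tpu_cores in (1, 8):\n",
        "        return '{} core(s)'.format(tpu_cores)\n",
        "    raise ValueError('expected 1, 8, or a one-element list')\n",
        "\n",
        "\n",
        "for value in (1, 8, [5]):\n",
        "    print(value, '->', describe_tpu_cores(value))\n",
        "```\n",
        "\n",
        "The three values in the loop are the same ones passed to `pl.Trainer` in the training cells of this notebook."
      ]
    },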
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UZ647Xg2gYng"
      },
      "source": [
        "Train on TPU core ID 5 with `tpu_cores=[5]`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bzhJ8g_vUxN2"
      },
      "source": [
        "# Init DataModule\n",
        "dm = MNISTDataModule()\n",
        "# Init model from datamodule's attributes\n",
        "model = LitModel(*dm.size(), dm.num_classes)\n",
        "# Init trainer\n",
        "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])\n",
        "# Train\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "slMq_0XBglzC"
      },
      "source": [
        "Train on single TPU core with `tpu_cores=1`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "31N5Scf2RZ61"
      },
      "source": [
        "# Init DataModule\n",
        "dm = MNISTDataModule()\n",
        "# Init model from datamodule's attributes\n",
        "model = LitModel(*dm.size(), dm.num_classes)\n",
        "# Init trainer\n",
        "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)\n",
        "# Train\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_v8xcU5Sf_Cv"
      },
      "source": [
        "Train on 8 TPU cores with `tpu_cores=8`. You might have to restart the notebook to run it on 8 TPU cores after training on single TPU core."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EFEw7YpLf-gE"
      },
      "source": [
        "# Init DataModule\n",
        "dm = MNISTDataModule()\n",
        "# Init model from datamodule's attributes\n",
        "model = LitModel(*dm.size(), dm.num_classes)\n",
        "# Init trainer\n",
        "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)\n",
        "# Train\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m2mhgEgpRZ1g"
      },
      "source": [
        "<code style=\"color:#792ee5;\">\n",
        "    <h1> <strong> Congratulations - Time to Join the Community! </strong>  </h1>\n",
        "</code>\n",
        "\n",
        "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
        "\n",
        "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
        "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
        "\n",
        "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
        "\n",
        "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
        "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
        "\n",
        "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
        "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
        "\n",
        "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
        "\n",
        "### Contributions !\n",
        "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
        "\n",
        "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
        "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
        "* You can also contribute your own notebooks with useful examples !\n",
        "\n",
        "### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
        "\n",
        "<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />"
      ]
    }
  ]
}
